[Relay][Frontend] Add support for aten::concat #16199
Conversation
@sweetcocoa Thank you for your PR!

@mshr-h, I agree with your comments.

Thank you @sweetcocoa for your PR! I agree the test is needed.
```python
b = (args[0][:, :, 1] + 3) * 11
c = (args[0][:, :, 2] + 5) * 13
return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)
```

Please remove trailing whitespaces.
```python
@tvm.testing.uses_gpu
def test_simple_rnn():
    """test_simple_rnn"""
```

nit: I suppose that all such whitespace is redundant. Could you please remove it in this test and in the tests below?

Sorry, I didn't see it; I've reverted it now.
Force-pushed from 03ddc08 to 6e23667
I don't think this CI error stems from this PR; can I try restarting it?

Yes, you can comment with:

@tvm-bot rerun
```python
class Concatenate3(Module):
    # pylint: disable=missing-class-docstring
    def __init__(self):
        super().__init__()

        class _Concatenate(Module):
            def forward(self, *args):
                a = (args[0][:, :, 0] + 2) * 7
                b = (args[0][:, :, 1] + 3) * 11
                c = (args[0][:, :, 2] + 5) * 13
                return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)

        self.mod = _Concatenate()

    def forward(self, *args):
        return self.mod(*args)
```
Why do you create a class in this way? Will the following code work in the same way?
Suggested change:

```python
class Concatenate3(Module):
    def forward(self, *args):
        a = (args[0][:, :, 0] + 2) * 7
        b = (args[0][:, :, 1] + 3) * 11
        c = (args[0][:, :, 2] + 5) * 13
        return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)
```
@echuraev
torch.concat is preserved as aten::concat only when it is inside a nested module, as in this code. (In most cases, it is converted to aten::cat instead of aten::concat.) I've tried to find a reason for this, but haven't found one.
Thank you for your reply! In that case, could you please explain it in the class docstring instead of using # pylint: disable=missing-class-docstring?
echuraev left a comment

LGTM! Thank you for your PR!
I think it is quite a simple problem: aten::concat is just an alias of aten::cat, but it is not supported. https://github.com/pytorch/pytorch/blob/3fbfa8cd0a5cefadb3f116c5cd0d60e96ab8c99e/aten/src/ATen/native/TensorShape.cpp#L667

If needed, I will add a minimal example to reproduce.
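Since aten::concat has the same semantics as aten::cat, the fix amounts to registering the existing converter under the alias name as well. A minimal sketch of that idea; the names `convert_map` and `_cat` are assumptions here, mirroring the style of TVM's PyTorch frontend rather than its exact code:

```python
# Sketch: register one converter under both operator names.
# `convert_map` and `_cat` are illustrative names (assumptions),
# standing in for TVM's real op-to-converter table.

def _cat(inputs, axis):
    """Toy stand-in for a cat/concat converter: concatenates
    Python lists along axis 0 for demonstration only."""
    assert axis == 0, "toy converter only handles axis 0"
    out = []
    for tensor in inputs:
        out.extend(tensor)
    return out


convert_map = {
    "aten::cat": _cat,
    "aten::concat": _cat,  # alias: same semantics, same converter
}

# Both operator names dispatch to the identical converter.
print(convert_map["aten::concat"]([[1, 2], [3]], 0))  # → [1, 2, 3]
```

Registering the alias this way keeps a single implementation, so any later fix to the cat converter automatically covers aten::concat too.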